/*    Copyright 2014 MongoDB Inc.
 *
 *    Licensed under the Apache License, Version 2.0 (the "License");
 *    you may not use this file except in compliance with the License.
 *    You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 *    Unless required by applicable law or agreed to in writing, software
 *    distributed under the License is distributed on an "AS IS" BASIS,
 *    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *    See the License for the specific language governing permissions and
 *    limitations under the License.
 */

#include "mongo/platform/basic.h"

#include "mongo/util/net/sock.h"

#include <boost/thread.hpp>

#ifndef _WIN32
#include <netdb.h>
#include <sys/socket.h>
#include <sys/types.h>
#endif

#include "mongo/unittest/unittest.h"
#include "mongo/util/concurrency/synchronization.h"
#include "mongo/util/fail_point_service.h"
#include "mongo/util/log.h"

namespace {

    using namespace mongo;

    typedef boost::shared_ptr<Socket> SocketPtr;
    typedef std::pair<SocketPtr, SocketPtr> SocketPair;

    // On UNIX, make a connected pair of PF_LOCAL (aka PF_UNIX) sockets via the native 'socketpair'
    // call. The 'type' parameter should be one of SOCK_STREAM, SOCK_DGRAM, SOCK_SEQPACKET, etc.
    // For Win32, we don't have a native socketpair function, so we hack up a connected PF_INET
    // pair on a random port.
    SocketPair socketPair(const int type, const int protocol = 0);

#if defined(_WIN32)
    namespace detail {
        // Thread body: accept one connection on 'listenSock' and store the new socket in
        // '*acceptSock' (INVALID_SOCKET if the accept failed), then signal 'notify'.
        void awaitAccept(SOCKET* acceptSock, SOCKET listenSock, Notification& notify) {
            *acceptSock = INVALID_SOCKET;
            const SOCKET result = ::accept(listenSock, NULL, 0);
            if (result != INVALID_SOCKET) {
                *acceptSock = result;
            }
            notify.notifyOne();
        }

        // Thread body: create a socket matching 'where' and connect it to 'where.ai_addr'.
        // On success the connected socket is stored in '*connectSock'; on any failure
        // '*connectSock' is INVALID_SOCKET and no socket is left open. Signals 'notify' when done.
        void awaitConnect(SOCKET* connectSock, const struct addrinfo& where, Notification& notify) {
            *connectSock = INVALID_SOCKET;
            SOCKET newSock = ::socket(where.ai_family, where.ai_socktype, where.ai_protocol);
            if (newSock != INVALID_SOCKET) {
                int result = ::connect(newSock, where.ai_addr, where.ai_addrlen);
                if (result == 0) {
                    *connectSock = newSock;
                }
                else {
                    // The caller never sees this socket, so release it here instead of leaking.
                    closesocket(newSock);
                }
            }
            notify.notifyOne();
        }
    } // namespace detail

    SocketPair socketPair(const int type, const int protocol) {

        const int domain = PF_INET;

        // Create the socket that will accept the single connection forming the pair.
        const SOCKET listenSock = ::socket(domain, type, protocol);
        if (listenSock == INVALID_SOCKET)
            return SocketPair();

        // Resolve an address with port zero so that bind picks a free ephemeral port for us.
        // AI_PASSIVE is deliberately NOT set: with a NULL node it would resolve to the wildcard
        // address, exposing the listener on every interface so that a remote peer could race
        // our own connect for the single accept below. Without it getaddrinfo returns the
        // loopback address, which is also what the connect side below resolves to.
        struct addrinfo hints, *res;
        ::memset(&hints, 0, sizeof(hints));
        hints.ai_family = PF_INET;
        hints.ai_socktype = type;

        int result = ::getaddrinfo(NULL, "0", &hints, &res);
        if (result != 0) {
            closesocket(listenSock);
            return SocketPair();
        }

        result = ::bind(listenSock, res->ai_addr, res->ai_addrlen);
        // The bind address is not needed past this point, whatever the outcome.
        ::freeaddrinfo(res);
        if (result != 0) {
            closesocket(listenSock);
            return SocketPair();
        }

        // Read out the port to which we bound.
        sockaddr_in bindAddr;
        ::socklen_t len = sizeof(bindAddr);
        ::memset(&bindAddr, 0, sizeof(bindAddr));
        result = ::getsockname(listenSock, reinterpret_cast<struct sockaddr*>(&bindAddr), &len);
        if (result != 0) {
            closesocket(listenSock);
            return SocketPair();
        }

        result = ::listen(listenSock, 1);
        if (result != 0) {
            closesocket(listenSock);
            return SocketPair();
        }

        // Resolve the address to connect to: the same (loopback) host on the port just bound.
        struct addrinfo connectHints, *connectRes;
        ::memset(&connectHints, 0, sizeof(connectHints));
        connectHints.ai_family = PF_INET;
        connectHints.ai_socktype = type;
        std::stringstream portStream;
        portStream << ntohs(bindAddr.sin_port);
        result = ::getaddrinfo(NULL, portStream.str().c_str(), &connectHints, &connectRes);
        if (result != 0) {
            closesocket(listenSock);
            return SocketPair();
        }

        // I'd prefer to avoid trying to do this non-blocking on Windows. Just spin up some
        // threads to do the connect and accept.

        Notification accepted;
        SOCKET acceptSock = INVALID_SOCKET;
        boost::thread acceptor(
            stdx::bind(&detail::awaitAccept, &acceptSock, listenSock, boost::ref(accepted)));

        Notification connected;
        SOCKET connectSock = INVALID_SOCKET;
        boost::thread connector(
            stdx::bind(&detail::awaitConnect, &connectSock, *connectRes, boost::ref(connected)));

        connected.waitToBeNotified();
        connector.join();
        // The connector worked on a copy of *connectRes whose ai_addr points into this list,
        // so the list may only be released once that thread has been joined.
        ::freeaddrinfo(connectRes);

        if (connectSock == INVALID_SOCKET) {
            // Closing the listen socket cancels the acceptor's pending blocking accept.
            // acceptSock is written by the acceptor thread, so it may only be inspected after
            // the join; if an accept did complete, release that socket rather than leaking it.
            closesocket(listenSock);
            acceptor.join();
            if (acceptSock != INVALID_SOCKET)
                closesocket(acceptSock);
            return SocketPair();
        }

        accepted.waitToBeNotified();
        acceptor.join();
        // The listener has served (or failed to serve) its one connection; it is done.
        closesocket(listenSock);

        if (acceptSock == INVALID_SOCKET) {
            closesocket(connectSock);
            return SocketPair();
        }

        SocketPtr first(new Socket(static_cast<int>(acceptSock), SockAddr()));
        SocketPtr second(new Socket(static_cast<int>(connectSock), SockAddr()));

        return SocketPair(first, second);
    }
#else
    // We can just use ::socketpair and wrap up the result in a Socket.
    // Returns an empty SocketPair (both pointers null) if ::socketpair fails.
    SocketPair socketPair(const int type, const int protocol) {
        // PF_LOCAL is the POSIX name for Unix domain sockets, while PF_UNIX
        // is the name that BSD used.  We use the BSD name because it is more
        // widely supported (e.g. Solaris 10).
        const int domain = PF_UNIX;

        int socks[2];
        const int result = ::socketpair(domain, type, protocol, socks);
        if (result != 0)
            return SocketPair();

        // Hand each Socket to a named shared_ptr in its own statement. Building both
        // shared_ptr temporaries inside one call expression lets a (pre-C++17) compiler run
        // both 'new Socket(...)' expressions before either shared_ptr takes ownership, so a
        // throw from the second allocation would leak the first Socket.
        SocketPtr first(new Socket(socks[0], SockAddr()));
        SocketPtr second(new Socket(socks[1], SockAddr()));
        return SocketPair(first, second);
    }
#endif

    // This should match the name of the fail point declared in sock.cpp.
    const char kSocketFailPointName[] = "throwSockExcep";

    // Fixture: looks up the socket fail point and creates a connected SOCK_STREAM pair.
    // Bytes sent on '_sockets.first' are received on '_sockets.second'. The try* helpers
    // return true so they can be wrapped in ASSERT_TRUE; failures surface as exceptions
    // (the tests expect SocketException while the fail point is enabled).
    class SocketFailPointTest : public unittest::Test {
    public:

        SocketFailPointTest() :
            _failPoint(getGlobalFailPointRegistry()->getFailPoint(kSocketFailPointName)),
            _sockets(socketPair(SOCK_STREAM)) {
            // Fail with a clear message instead of dereferencing a null pointer later if the
            // fail point is not registered or the socket pair could not be created.
            ASSERT_TRUE(_failPoint != NULL);
            ASSERT_TRUE(_sockets.first.get() != NULL);
            ASSERT_TRUE(_sockets.second.get() != NULL);
        }

        // Sends one byte on the sending end.
        bool trySend() {
            char byte = 'x';
            _sockets.first->send(&byte, sizeof(byte), "SocketFailPointTest::trySend");
            return true;
        }

        // Sends one byte on the sending end through the vectored send overload.
        bool trySendVector() {
            std::vector<std::pair<char*, int> > data;
            char byte = 'x';
            data.push_back(std::make_pair(&byte, sizeof(byte)));
            _sockets.first->send(data, "SocketFailPointTest::trySendVector");
            return true;
        }

        // Receives exactly one byte on the receiving end.
        bool tryRecv() {
            char byte;
            _sockets.second->recv(&byte, sizeof(byte));
            return true;
        }

        // Receives up to 'max' bytes on the receiving end and returns how many arrived.
        // You must queue at least one byte on the send socket before calling this function
        // (with max > 0), otherwise the blocking recv will not return.
        size_t countRecvable(size_t max) {
            if (max == 0)
                return 0;  // &buf[0] would not be a valid pointer into an empty vector.
            std::vector<char> buf(max);
            // This isn't great, because we don't have a guarantee that multiple sends will be
            // captured in one recv. However, sock doesn't let us pass flags into recv, so we
            // can't make this non blocking, and therefore can't risk another call.
            return _sockets.second->unsafe_recv(&buf[0], max);
        }

        FailPoint* const _failPoint;
        const SocketPair _sockets;
    };

    // RAII guard: switches 'fp' to FailPoint::alwaysOn for the guard's lifetime and back to
    // FailPoint::off when the guard is destroyed (including during exception unwinding).
    class ScopedFailPointEnabler {
    public:
        explicit ScopedFailPointEnabler(FailPoint& fp)
            : _fp(fp) {
            _fp.setMode(FailPoint::alwaysOn);
        }

        ~ScopedFailPointEnabler() {
            _fp.setMode(FailPoint::off);
        }
    private:
        // Non-copyable (declared, never defined): a copy would switch the fail point off a
        // second time when destroyed, possibly while the original guard expects it to be on.
        ScopedFailPointEnabler(const ScopedFailPointEnabler&);
        ScopedFailPointEnabler& operator=(const ScopedFailPointEnabler&);

        FailPoint& _fp;
    };

    TEST_F(SocketFailPointTest, TestSend) {
        // Baseline: with the fail point off a byte makes the round trip.
        ASSERT_TRUE(trySend());
        ASSERT_TRUE(tryRecv());

        {
            // While the fail point is on, send must throw.
            const ScopedFailPointEnabler failPointGuard(*_failPoint);
            ASSERT_THROWS(trySend(), SocketException);
        }

        // Guard released: the channel should be usable again.
        ASSERT_TRUE(trySend());
        ASSERT_TRUE(tryRecv());
    }

    TEST_F(SocketFailPointTest, TestSendVector) {
        // Baseline: the vectored send works with the fail point off.
        ASSERT_TRUE(trySendVector());
        ASSERT_TRUE(tryRecv());

        {
            // While the fail point is on, the vectored send must throw.
            const ScopedFailPointEnabler failPointGuard(*_failPoint);
            ASSERT_THROWS(trySendVector(), SocketException);
        }

        // Guard released: the channel should be usable again.
        ASSERT_TRUE(trySendVector());
        ASSERT_TRUE(tryRecv());
    }

    TEST_F(SocketFailPointTest, TestRecv) {
        // Every recv below is preceded by a send so there is data waiting to be read.
        ASSERT_TRUE(trySend());
        ASSERT_TRUE(tryRecv());

        ASSERT_TRUE(trySend());
        {
            // While the fail point is on, recv must throw even though data is queued.
            const ScopedFailPointEnabler failPointGuard(*_failPoint);
            ASSERT_THROWS(tryRecv(), SocketException);
        }

        // Guard released: a fresh send/recv should succeed.
        ASSERT_TRUE(trySend());
        ASSERT_TRUE(tryRecv());
    }

    TEST_F(SocketFailPointTest, TestFailedSendsDontSend) {
        ASSERT_TRUE(trySend());
        ASSERT_TRUE(tryRecv());

        // Queue exactly one byte while the fail point is still off.
        ASSERT_TRUE(trySend());
        {
            // A second send is attempted with the fail point on and must throw.
            const ScopedFailPointEnabler failPointGuard(*_failPoint);
            ASSERT_THROWS(trySend(), SocketException);
        }

        // Only the first byte may be waiting; the rejected one must never have gone out.
        ASSERT_EQUALS(size_t(1), countRecvable(2));
    }

    // Ensure that a vectored send rejected by the fail point doesn't enqueue data on the socket.
    TEST_F(SocketFailPointTest, TestFailedVectorSendsDontSend) {
        ASSERT_TRUE(trySend());
        ASSERT_TRUE(tryRecv());

        // Queue exactly one byte while the fail point is still off.
        ASSERT_TRUE(trySend());
        {
            // A vectored send is attempted with the fail point on and must throw.
            const ScopedFailPointEnabler failPointGuard(*_failPoint);
            ASSERT_THROWS(trySendVector(), SocketException);
        }

        // Only the first byte may be waiting; the rejected one must never have gone out.
        ASSERT_EQUALS(size_t(1), countRecvable(2));
    }

    TEST_F(SocketFailPointTest, TestFailedRecvsDontRecv) {
        ASSERT_TRUE(trySend());
        ASSERT_TRUE(tryRecv());

        // Queue one byte, then try (and fail) to read it with the fail point on.
        ASSERT_TRUE(trySend());
        {
            const ScopedFailPointEnabler failPointGuard(*_failPoint);
            ASSERT_THROWS(tryRecv(), SocketException);
        }

        // The rejected recv must not have consumed the byte: it is still waiting.
        ASSERT_EQUALS(size_t(1), countRecvable(1));

        // Guard released: a fresh send/recv should succeed.
        ASSERT_TRUE(trySend());
        ASSERT_TRUE(tryRecv());
    }


} // namespace
